import argparse
import torch.distributed as dist
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import test2  # import test.py to get mAP after each epoch
from models import *
from utils.datasets import *
from utils.utils import *

# Checkpoint locations. ``wdir`` keeps its trailing path separator because
# other code appends file names to it by plain string concatenation.
wdir = os.path.join('weights', '')  # weights dir
last = os.path.join(wdir, 'last.pt')  # most recent checkpoint
best = os.path.join(wdir, 'best.pt')  # checkpoint with the best fitness so far
results_file = 'results.txt'  # per-epoch results log

# Hyperparameters (results68: 59.2 mAP@0.5 yolov3-spp-416) https://github.com/ultralytics/yolov3/issues/310
# NOTE: key order matters -- the optional hyp*.txt override assigns values positionally.
hyp = dict(
    giou=3.54,  # giou loss gain
    cls=37.4,  # cls loss gain
    cls_pw=1.0,  # cls BCELoss positive_weight
    obj=49.5,  # obj loss gain (*=img_size/320 if img_size != 320)
    obj_pw=1.0,  # obj BCELoss positive_weight
    iou_t=0.225,  # iou training threshold
    lr0=0.01,  # initial learning rate (SGD=1E-3, Adam=9E-5)
    lrf=-4.,  # final LambdaLR learning rate = lr0 * (10 ** lrf)
    momentum=0.937,  # SGD momentum
    weight_decay=0.000484,  # optimizer weight decay
    fl_gamma=0.5,  # focal loss gamma
    hsv_h=0.0138,  # image HSV-Hue augmentation (fraction)
    hsv_s=0.678,  # image HSV-Saturation augmentation (fraction)
    hsv_v=0.36,  # image HSV-Value augmentation (fraction)
    degrees=1.98,  # image rotation (+/- deg)
    translate=0.05,  # image translation (+/- fraction)
    scale=0.05,  # image scale (+/- gain)
    shear=0.641,  # image shear (+/- deg)
)

# Optional override: when a hyp*.txt file is present in the working directory,
# its values replace the defaults positionally, in the order the keys of `hyp`
# were declared. Only the first file returned by glob is used.
f = glob.glob('hyp*.txt')
if f:
    print('Using %s' % f[0])
    for k, v in zip(list(hyp), np.loadtxt(f[0])):
        hyp[k] = v

########################################
BIT_WIDTH = 8  # total bits (sign included) of the fixed-point format used by quan_weight / quan_activation
def get_fl(max_value, quant_bit):
    """Return the fractional length (number of fractional bits) for a signed fixed-point format.

    The integer length is ``ceil(log2(max_value))`` and one bit is reserved for
    the sign, so the result is ``quant_bit - 1 - ceil(log2(max_value))``. The
    result may be negative (large magnitudes) or exceed ``quant_bit`` (small
    magnitudes).

    Args:
        max_value: largest absolute value that has to be represented. Any
            real-valued scalar accepted by ``math.log2`` works (Python number
            or 0-d array).
        quant_bit: total word length in bits, sign bit included.

    Returns:
        The fractional length as an int, or 0 when ``max_value`` is not positive.
    """
    if max_value > 0:
        # math.log2 is exact for powers of two. math.log(x, 2) is evaluated as a
        # quotient of natural logarithms and can come out slightly above the
        # true integer (e.g. for 2 ** 29), which math.ceil would then round up
        # to an integer length that is one too large.
        il = math.ceil(math.log2(max_value))
        return quant_bit - 1 - il
    return 0

def quantzie(x, fl, quant_bit):
    """Quantize tensor ``x`` to a signed fixed-point grid and return the result as a new tensor.

    Values are first saturated to the representable range of a ``quant_bit``-bit
    signed number with ``fl`` fractional bits, then rounded to the nearest
    multiple of ``2 ** -fl``. The input tensor is not modified.

    Args:
        x: tensor to quantize.
        fl: fractional length (number of fractional bits); sets the step size.
        quant_bit: total word length in bits, sign bit included.

    Returns:
        A tensor of the same shape as ``x`` holding the quantized values.
    """
    step = math.pow(2, -fl)  # value of one least-significant bit
    half_range = math.pow(2, quant_bit - 1)

    # Saturate to [-2^(b-1), 2^(b-1) - 1] steps
    upper = (half_range - 1) * step
    lower = -half_range * step
    saturated = torch.clamp(x, lower, upper)

    # Snap to the grid: scale to integer steps, round, scale back
    return torch.round(saturated / step) * step

def quan_weight(model):
    """Quantize every parameter of ``model`` in place to ``BIT_WIDTH``-bit fixed point.

    Each parameter tensor gets its own fractional length, derived from its
    largest absolute value, and its ``.data`` is replaced by the quantized copy.

    Args:
        model: module whose parameters are quantized.
    """
    for param in model.parameters():
        values = param.data.clone()
        peak = torch.max(torch.abs(values)).cpu().detach().numpy()
        frac_len = get_fl(peak, BIT_WIDTH)
        param.data = quantzie(values, frac_len, BIT_WIDTH)

def quan_activation(output):
    """Return ``output`` quantized to ``BIT_WIDTH``-bit fixed point.

    The fractional length is chosen from the largest absolute value in the
    tensor, so the scale is recomputed on every call.

    Args:
        output: activation tensor to quantize.

    Returns:
        A new tensor with the quantized values.
    """
    peak = torch.max(torch.abs(output)).cpu().detach().numpy()
    return quantzie(output, get_fl(peak, BIT_WIDTH), BIT_WIDTH)

def prune_layer(x, threshold):
    """Zero, in place, every element of ``x`` whose magnitude is below ``threshold``.

    Args:
        x: tensor to prune; it is modified in place.
        threshold: elements with ``abs(value) < threshold`` are set to 0.

    Returns:
        The same tensor object ``x``.
    """
    too_small = x.abs().lt(threshold)
    x.masked_fill_(too_small, 0)
    return x

# def prune_weight(model, threshold):
#     for i, named_parameter in enumerate(model.named_parameters()):
#         name, parameters=named_parameter
#         new_weight_data = parameters.data.clone()
#         new_weight_data=prune_layer(new_weight_data, threshold)
#         parameters.data = new_weight_data

def prune_weight(model, percent):
    """Globally magnitude-prune the ``nn.Conv2d`` weights of ``model`` in place.

    The absolute values of all Conv2d weights are sorted together and the value
    at index ``int(total * percent)`` becomes the threshold. Every Conv2d weight
    whose magnitude is not strictly greater than the threshold is set to zero,
    so weights tied with the threshold are pruned as well. Per-layer and overall
    statistics are printed.

    Args:
        model: module whose Conv2d layers are pruned.
        percent: fraction of the sorted weight magnitudes that selects the
            threshold. The resulting index is clamped to the valid range, so
            ``percent >= 1`` uses the largest magnitude as the threshold
            instead of indexing past the end.

    Returns:
        None. The model is modified in place.
    """
    conv_layers = [(k, m) for k, m in enumerate(model.modules()) if isinstance(m, nn.Conv2d)]
    total = sum(m.weight.data.numel() for _, m in conv_layers)
    if total == 0:
        # Nothing to sort or index; the threshold lookup below would fail.
        print('No Conv2d weights found, nothing to prune')
        return

    # Collect all magnitudes on the CPU as float32 so layers living on
    # different devices or in different dtypes can be ranked together.
    conv_weights = torch.cat([m.weight.data.abs().reshape(-1).float().cpu() for _, m in conv_layers])
    y, _ = torch.sort(conv_weights)
    thre_index = min(max(int(total * percent), 0), total - 1)
    thre = y[thre_index].item()  # plain float: usable against a tensor on any device
    print('Pruning threshold: {}'.format(thre))

    pruned = 0
    for k, m in conv_layers:
        # The mask is built from the weight itself, so it lives on the same
        # device (no hard-coded .cuda(), which broke CPU-only runs).
        keep = m.weight.data.abs().gt(thre)
        remaining = int(keep.sum())  # boolean sum is an exact integer count
        pruned += keep.numel() - remaining
        m.weight.data.mul_(keep.to(m.weight.dtype))
        print('layer index: {:d} \t total params: {:d} \t remaining params: {:d}'.
              format(k, keep.numel(), remaining))
    print('Total conv params: {}, Pruned conv params: {}, Pruned ratio: {}'.format(total, pruned, pruned / total))
#####################
def train():
    """Train (optionally prune / quantize / fine-tune) a Darknet model.

    Configuration is read from module-level globals that the ``__main__``
    section sets up before calling this function: ``opt`` (parsed command-line
    arguments), ``device``, ``hyp`` and ``tb_writer``.

    For every epoch the function runs one pass over the training dataloader,
    steps the LR scheduler, re-quantizes the parameters when ``--quan`` is set,
    evaluates with ``test2.test``, appends a line to ``results.txt`` and saves
    checkpoints: ``last.pt`` always, ``best.pt`` when the summed validation
    loss is the lowest so far, and a numbered backup every 10 epochs.

    Returns:
        The tuple returned by the most recent ``test2.test`` call, or the
        initial all-zero tuple if no epoch completed. Training stops early and
        returns the previous results if the loss becomes non-finite.

    Raises:
        KeyError: if the ``.pt`` checkpoint does not match the model built
            from ``opt.cfg``.
    """
    cfg = opt.cfg
    data = opt.data
    img_size = opt.img_size
    epochs = opt.epochs  # 500200 batches at bs 64, 117263 images = 273 epochs
    batch_size = opt.batch_size
    accumulate = opt.accumulate  # effective bs = batch_size * accumulate = 16 * 4 = 64
    weights = opt.weights  # initial training weights

    if 'pw' not in opt.arc:  # remove BCELoss positive weights
        hyp['cls_pw'] = 1.
        hyp['obj_pw'] = 1.

    # Initialize
    init_seeds()

    # Configure run
    data_dict = parse_data_cfg(data)
    train_path = data_dict['train']
    test_path = data_dict['valid']
    nc = int(data_dict['classes'])  # number of classes

    # Remove previous results
    for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
        os.remove(f)

    # Initialize model
    model = Darknet(cfg, arc=opt.arc).to(device)

    # Optimizer
    pg0, pg1 = [], []  # optimizer parameter groups
    for k, v in model.named_parameters():
        if 'Conv2d.weight' in k:
            pg1 += [v]  # parameter group 1 (apply weight_decay)
        else:
            pg0 += [v]  # parameter group 0

    if opt.adam:
        optimizer = optim.Adam(pg0, lr=hyp['lr0'])
    else:
        optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
    optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay

    start_epoch = 0
    best_fitness = float('inf')  # fitness is a summed validation loss: lower is better
    attempt_download(weights)
    if weights.endswith('.pt'):  # pytorch format
        # possible weights are '*.pt', 'yolov3-spp.pt', 'yolov3-tiny.pt' etc.
        chkpt = torch.load(weights, map_location=device)

        try:
            # Build the state dict once instead of once per checkpoint key, and
            # keep only the tensors whose element count matches the model.
            model_state = model.state_dict()
            chkpt['model'] = {k: v for k, v in chkpt['model'].items() if model_state[k].numel() == v.numel()}
            model.load_state_dict(chkpt['model'], strict=False)
        except KeyError as e:
            s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s. " % (opt.weights, opt.cfg, opt.weights)
            raise KeyError(s) from e

        # load optimizer
        if chkpt['optimizer'] is not None:
            optimizer.load_state_dict(chkpt['optimizer'])
            best_fitness = chkpt['best_fitness']

        # load results
        if chkpt.get('training_results') is not None:
            with open(results_file, 'w') as file:
                file.write(chkpt['training_results'])  # write results.txt

        start_epoch = chkpt['epoch'] + 1
        del chkpt

    elif len(weights) > 0:  # darknet format
        # possible weights are '*.weights', 'yolov3-tiny.conv.15',  'darknet53.conv.74' etc.
        load_darknet_weights(model, weights)

        # NOTE(review): every parameter is frozen on this path, so no gradients
        # are produced for the optimizer to apply -- confirm this is intended.
        for p in model.parameters():
            p.requires_grad = False

    if opt.quan or opt.finetune:
        # NOTE(review): this replaces the optimizer built above. Any optimizer
        # state restored from the checkpoint is discarded and only pg0 is
        # optimized, so the parameters in pg1 (names containing 'Conv2d.weight')
        # receive no updates -- confirm this is intended.
        optimizer = optim.SGD(pg0, lr=0.0005, momentum=0.9, nesterov=True)

    del pg0, pg1

    if opt.quan or opt.finetune:
        milestones = [start_epoch + 40]
    else:
        milestones = [round(opt.epochs * x) for x in [0.8, 0.9]]
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
    scheduler.last_epoch = start_epoch - 1

    # Dataset
    dataset = LoadImagesAndLabels(train_path, img_size, batch_size,
                                  augment=True,
                                  hyp=hyp,  # augmentation hyperparameters
                                  rect=opt.rect,  # rectangular training
                                  image_weights=False,
                                  cache_labels=True,
                                  cache_images=opt.cache_images)

    # Dataloader
    batch_size = min(batch_size, len(dataset))
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             num_workers=nw,
                                             shuffle=not opt.rect,  # Shuffle=True unless rectangular training is used
                                             pin_memory=True,
                                             collate_fn=dataset.collate_fn)

    # Test Dataloader
    testloader = torch.utils.data.DataLoader(LoadImagesAndLabels(test_path, opt.img_size, batch_size,
                                                                 hyp=hyp,
                                                                 rect=True,
                                                                 cache_labels=True,
                                                                 cache_images=opt.cache_images),
                                             batch_size=batch_size,
                                             num_workers=nw,
                                             pin_memory=True,
                                             collate_fn=dataset.collate_fn)

    # Start training
    nb = len(dataloader)
    model.nc = nc  # attach number of classes to model
    model.arc = opt.arc  # attach yolo architecture
    model.hyp = hyp  # attach hyperparameters to model
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights
    print(model.hyp)
    results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
    t0 = time.time()
    torch_utils.model_info(model, report='summary')  # 'full' or 'summary'
    print('Using %g dataloader workers' % nw)
    print('Starting %s for %g epochs...' % ('training', epochs))
    print('start_epoch=', start_epoch)

    # One-off compression of the initial weights before the first epoch
    if opt.prune:
        prune_weight(model, opt.percent)

    if opt.quan:
        quan_weight(model)

    for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
        model.train()
        print("Learning rate = %g" % scheduler.get_last_lr()[0])
        print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))

        mloss = torch.zeros(4).to(device)  # mean losses
        s = ''  # progress line; stays empty if the dataloader yields no batches
        pbar = tqdm(enumerate(dataloader), total=nb)  # progress bar
        for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
            targets = targets.to(device)

            # Plot images with bounding boxes
            if ni == 0:
                fname = 'train_batch%g.jpg' % i
                plot_images(imgs=imgs, targets=targets, paths=paths, fname=fname)
                if tb_writer:
                    tb_writer.add_image(fname, cv2.imread(fname)[:, :, ::-1], dataformats='HWC')

            # Run model
            pred = model(imgs)
            # Compute loss
            loss, loss_items = compute_loss(pred, targets, model)
            if not torch.isfinite(loss):
                print('WARNING: non-finite loss, ending training ', loss_items)
                return results

            # Scale loss by nominal batch_size of 64
            loss *= batch_size / 64
            # With every parameter frozen the loss has no autograd history, and
            # backward() needs it flagged as requiring grad. When the loss
            # already requires grad it is a non-leaf tensor, and assigning
            # .requires_grad on a non-leaf raises RuntimeError -- so the flag is
            # only set when it is missing.
            if not loss.requires_grad:
                loss.requires_grad_(True)
            loss.backward()

            # Accumulate gradient for x batches before optimizing
            if ni % accumulate == 0:
                optimizer.step()
                optimizer.zero_grad()

            # Print batch results
            mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses

            mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0  # (GB)
            # NOTE(review): the column printed under the 'targets' header is the scaled batch loss.
            s = ('%10s' * 2 + '%10.3g' * 6) % ('%g/%g' % (epoch, epochs - 1), '%.3gG' % mem, *mloss, loss.detach().cpu(), img_size)
            pbar.set_description(s)
            # end batch ------------------------------------------------------------------------------------------------

        # Update scheduler
        scheduler.step()

        # Re-quantize after the epoch's updates so evaluation and checkpoints use fixed-point weights
        if opt.quan:
            quan_weight(model)

        # Process epoch results
        final_epoch = epoch + 1 == epochs

        # Calculate mAP
        results = test2.test(cfg,
                             data,
                             batch_size=batch_size,
                             img_size=opt.img_size,
                             model=model,
                             conf_thres=0.001 if final_epoch else 0.1,  # 0.1 for speed
                             save_json=final_epoch,
                             dataloader=testloader,
                             opt=opt)

        # Write epoch results
        with open(results_file, 'a') as f:
            f.write(s + '%10.3g' * 7 % results + '\n')  # P, R, mAP, F1, test_losses=(GIoU, obj, cls)

        # Write Tensorboard results
        if tb_writer:
            x = list(mloss) + list(results)
            titles = ['GIoU', 'Objectness', 'Classification', 'Train loss',
                      'Precision', 'Recall', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification']
            for xi, title in zip(x, titles):
                tb_writer.add_scalar(title, xi, epoch)

        # Update best fitness (sum of the validation losses, lower is better)
        fitness = sum(results[4:])  # total loss
        if fitness < best_fitness:
            best_fitness = fitness

        # Save training results
        with open(results_file, 'r') as f:
            # Create checkpoint
            chkpt = {'epoch': epoch,
                     'best_fitness': best_fitness,
                     'training_results': f.read(),
                     'model': model.module.state_dict() if isinstance(
                         model, nn.parallel.DistributedDataParallel) else model.state_dict(),
                     'optimizer': None if final_epoch else optimizer.state_dict()}

        # Save last checkpoint
        torch.save(chkpt, last)

        # Save best checkpoint
        if best_fitness == fitness:
            torch.save(chkpt, best)

        # Save backup every 10 epochs (optional)
        if epoch > 0 and epoch % 10 == 0:
            torch.save(chkpt, wdir + 'backup%g.pt' % epoch)

        # Delete checkpoint
        del chkpt
        # end epoch ----------------------------------------------------------------------------------------------------

    # end training
    plot_results()  # save as results.png
    # Epochs start_epoch .. epochs - 1 were run, i.e. epochs - start_epoch of them
    print('%g epochs completed in %.3f hours.\n' % (max(epochs - start_epoch, 0), (time.time() - t0) / 3600))
    torch.cuda.empty_cache()
    return results


if __name__ == '__main__':
    # Command-line entry point: parse options, publish the module-level globals
    # that train() reads (opt, device, tb_writer), then start training.
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=273)  # 500200 batches at bs 16, 117263 images = 273 epochs
    parser.add_argument('--batch-size', type=int, default=16)  # effective bs = batch_size * accumulate = 16 * 4 = 64
    parser.add_argument('--accumulate', type=int, default=4, help='batches to accumulate before optimizing')
    parser.add_argument('--cfg', type=str, default='cfg/yolov3-spp.cfg', help='*.cfg path')
    parser.add_argument('--data', type=str, default='data/voc.data', help='*.data path')
    parser.add_argument('--img-size', type=int, default=288, help='inference size (pixels)')
    parser.add_argument('--rect', action='store_true', help='rectangular training')
    parser.add_argument('--resume', action='store_true', help='resume training from last.pt')

    parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
    parser.add_argument('--weights', type=str, default='', help='initial weights')
    parser.add_argument('--arc', type=str, default='default', help='yolo architecture')  # defaultpw, uCE, uBCE
    parser.add_argument('--device', default='0', help='device id (i.e. 0 or 0,1 or cpu)')
    parser.add_argument('--adam', action='store_true', help='use adam optimizer')
    # Compression options
    parser.add_argument('--quan', action='store_true', help='quantize the model')
    parser.add_argument('--prune', action='store_true', help='prune the model')
    parser.add_argument('--percent', type=float, default=0.1,
                        help='fraction of sorted conv weight magnitudes that sets the pruning threshold')
    parser.add_argument('--finetune', action='store_true', help='finetune the model')

    opt = parser.parse_args()
    opt.weights = last if opt.resume else opt.weights  # --resume overrides --weights
    print(opt)
    device = torch_utils.select_device(opt.device, apex=False, batch_size=opt.batch_size)

    hyp['obj'] *= opt.img_size / 320.  # scale the obj loss gain with image size

    tb_writer = None
    try:
        # Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/
        from torch.utils.tensorboard import SummaryWriter
        tb_writer = SummaryWriter()
    except Exception as e:
        # Tensorboard is optional: keep training without it, but say why it is
        # off. Unlike a bare except, this lets Ctrl-C and SystemExit through.
        print('Tensorboard logging disabled: %s' % e)
    train()  # train normally